{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# `torch.nn.Embedding`\n",
    "\n",
    "> ‘我一直对Embedding有一层抽象的，模糊的认识’\n",
    "\n",
    "参考[我最爱的b站up主的内容](https://www.bilibili.com/video/BV1wm4y187Cr/?spm_id_from=333.337.search-card.all.click&vd_source=32f9de072b771f1cd307ca15ecf84087)\n",
    "\n",
    "## embedding的基础概念\n",
    "\n",
    "`embedding`是将词向量中的词映射为固定长度的词向量的技术，可以将one_hot出来的高维度的稀疏的向量转化成低维的连续的向量\n",
    "\n",
    "![直观显示词与词之间的关系](../assets/lecture-pic/Embedding1.png)\n",
    "\n",
    "## 首先明白embedding的计算过程\n",
    "\n",
    "- embedding module 的前向过程是一个索引(查表)的过程\n",
    "    - 表的形式是一个matrix （也即 embedding.weight,learnabel parameters）\n",
    "        - matrix.shape:(v,h)\n",
    "            - v:vocabulary size\n",
    "            - h:hidden dimension\n",
    "\n",
    "    - 具体的索引的过程，是通过onehot+矩阵乘法的形式实现的\n",
    "    - input.shape:(b,s)\n",
    "        - b: batch size\n",
    "        - s: seq len \n",
    "    - embedding(input)=>(b,s,h)\n",
    "    - **这其中关键的问题就是(b,s)和(v,h)怎么变成了(b,s,h)**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[[-0.4682, -1.2143,  0.1752],\n",
       "         [ 1.5085,  0.4936,  0.3845],\n",
       "         [-1.1064,  1.0143,  0.4442],\n",
       "         [ 0.6037,  0.6854,  0.3562]],\n",
       "\n",
       "        [[-1.1064,  1.0143,  0.4442],\n",
       "         [ 1.0134,  0.2836, -0.6358],\n",
       "         [ 1.5085,  0.4936,  0.3845],\n",
       "         [ 1.5759, -0.5384, -0.0649]]], grad_fn=<EmbeddingBackward0>)"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "embedding = nn.Embedding(10, 3)\n",
    "# a batch of 2 samples of 4 indices each\n",
    "input = torch.LongTensor([[1, 2, 4, 5], [4, 3, 2, 9]])\n",
    "embedding(input)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## One-Hot 矩阵乘法\n",
    "\n",
    "目前`one_hot`可以很方便地在`torch.nn.functional`中进行调用，对于一个[batchsize,seqlength]的tensor，one_hot向量可以十分方便的将其转化为[batchsize,seqlength,numclasses],此时，再与[numclasses,h]进行相乘，从而得到最终的[b,s,v]\n",
    "\n",
    "## 参数padding_idx\n",
    "\n",
    "这个参数的作用是指定某个位置的梯度不进行更新，但是为什么不进行更新，以及在哪个位置不进行更新我还没搞明白...."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Parameter containing:\n",
       "tensor([[ 1.0000,  1.0000,  1.0000],\n",
       "        [-0.0073,  0.8613, -0.4185],\n",
       "        [-0.1206, -0.8382,  0.4391]], requires_grad=True)"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# example with padding_idx\n",
    "embedding = nn.Embedding(10, 3, padding_idx=0)\n",
    "input = torch.LongTensor([[0, 2, 0, 5]])\n",
    "embedding(input)\n",
    "# example of changing `pad` vector\n",
    "padding_idx = 0\n",
    "embedding = nn.Embedding(3, 3, padding_idx=padding_idx)\n",
    "embedding.weight\n",
    "with torch.no_grad():\n",
    "     embedding.weight[padding_idx] = torch.ones(3)\n",
    "embedding.weight"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 关于max_norm\n",
    "\n",
    "这个参数用于设置输出和权重参数是否经过了正则化。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "pytorch",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
